{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
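    "# Solve Single Crosswalk using QMDP and SARSOP\n",
    "\n",
    "This notebook builds a `SingleOCPOMDP` crosswalk problem, simulates a random-policy baseline, then solves the problem with QMDP and with SARSOP, rendering one simulated run for each policy."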
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Method definition info(Any...) in module Base at util.jl:532 overwritten in module Logging at /mnt/c/Users/Maxime/wsl/.julia/v0.6/Logging/src/Logging.jl:115.\n",
      "WARNING: Method definition warn(Any...) in module Base at util.jl:585 overwritten in module Logging at /mnt/c/Users/Maxime/wsl/.julia/v0.6/Logging/src/Logging.jl:115.\n"
     ]
    }
   ],
   "source": [
    "using POMDPs, StatsBase, POMDPToolbox, RLInterface, Parameters, GridInterpolations\n",
    "using AutomotiveDrivingModels,AutoViz\n",
    "using Reel \n",
    "using QMDP, SARSOP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Method definition copyto!(Array{Float64, 1}, AutomotiveDrivingModels.LatLonAccel) in module AutomotiveDrivingModels at /mnt/c/Users/Maxime/wsl/.julia/v0.6/AutomotiveDrivingModels/src/2d/actions/lat_lon_accel.jl:13 overwritten in module AutoUrban at /mnt/c/Users/Maxime/wsl/.julia/v0.6/AutoUrban/src/simulation/actions.jl:10.\n"
     ]
    }
   ],
   "source": [
    "using AutomotivePOMDPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random number generator seeded with 1; later cells pass it to RandomSolver,\n",
    "# HistoryRecorder and initialstate.\n",
    "rng = MersenneTwister(1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Crosswalk environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Method definition (::Type{AutomotivePOMDPs.CrosswalkEnv})() in module AutomotivePOMDPs at /home/maxime/Maxime/OneDrive - Leland Stanford Junior University/Research/AutomotivePOMDPs/src/envs/occluded_crosswalk_env.jl:56 overwritten at /home/maxime/Maxime/OneDrive - Leland Stanford Junior University/Research/AutomotivePOMDPs/src/envs/occluded_crosswalk_env.jl:56.\n",
      "WARNING: Method definition (::Type{AutomotivePOMDPs.CrosswalkEnv})(AutomotivePOMDPs.CrosswalkParams) in module AutomotivePOMDPs at /home/maxime/Maxime/OneDrive - Leland Stanford Junior University/Research/AutomotivePOMDPs/src/envs/occluded_crosswalk_env.jl:56 overwritten at /home/maxime/Maxime/OneDrive - Leland Stanford Junior University/Research/AutomotivePOMDPs/src/envs/occluded_crosswalk_env.jl:56.\n",
      "06-Aug 09:25:38:WARNING:root:replacing docs for 'AutomotivePOMDPs.CrosswalkEnv :: Union{Tuple{AutomotivePOMDPs.CrosswalkParams}, Tuple{}}' in module 'AutomotivePOMDPs'.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AutomotivePOMDPs.SingleOCPOMDP(AutomotivePOMDPs.CrosswalkEnv(Roadway, AutomotiveDrivingModels.Lane(LaneTag(2, 1), AutomotiveDrivingModels.CurvePt[CurvePt({25.000, -10.000, 1.571}, 0.000, 0.000, NaN), CurvePt({25.000, 10.000, 1.571}, 20.000, 0.000, NaN)], 6.0, AutomotiveDrivingModels.SpeedLimit(-Inf, Inf), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneConnection[], AutomotiveDrivingModels.LaneConnection[]), AutomotiveDrivingModels.ConvexPolygon[ConvexPolygon: len 4 (max 4 pts)\n",
       "\tVecE2(15.000, -1.500)\n",
       "\tVecE2(15.000, -4.500)\n",
       "\tVecE2(21.500, -4.500)\n",
       "\tVecE2(21.500, -1.500)\n",
       "], AutomotivePOMDPs.CrosswalkParams(2, 50.0, 3.0, 20.0, 6.0, 5.0, 37.0, 8.0, 100, 0.5, 2.0, 10.0)), VehicleDef(CAR, 4.000, 1.800), VehicleDef(PEDESTRIAN, 1.000, 1.000), 2.0, 1.0, 1.0, 5.0, -5.0, 37.0, 5.0, 0.5, 0.3, 0.3, false, 1.0, 1.0, 1.0, -1.0, 0.0, 1.0, 0.95)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Construct the SingleOCPOMDP instance used by the random-policy and QMDP cells below.\n",
    "# NOTE(review): ΔT is presumably the time step and p_birth the probability that a\n",
    "# pedestrian appears -- confirm against the AutomotivePOMDPs definition.\n",
    "pomdp = SingleOCPOMDP(ΔT = 0.5, p_birth = 0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the policy returned by solving with RandomSolver (driven by the seeded rng).\n",
    "policy = solve(RandomSolver(rng), pomdp)\n",
    "# Belief updater obtained from that policy; used when simulating it below.\n",
    "up = updater(policy);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  3.222652 seconds (1.54 M allocations: 73.924 MiB, 1.47% gc time)\n"
     ]
    }
   ],
   "source": [
    "# Recorder that keeps the simulation history, limited to 200 steps.\n",
    "hr = HistoryRecorder(rng=rng, max_steps=200)\n",
    "# Simulate the random policy with its updater and report the elapsed time.\n",
    "@time hist = simulate(hr, pomdp, policy, up);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-54739127818707822.webm?8672762950185310570\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"/tmp/tmpHz9206\", 0x0000000000000042, 2.0, nothing)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# animate_hist returns three values, unpacked here as duration, fps and render_hist;\n",
    "# an empty SceneOverlay vector is passed explicitly for this run.\n",
    "duration, fps, render_hist = animate_hist(pomdp, hist, SceneOverlay[])\n",
    "# Assemble the film from render_hist; it is displayed as the cell's last expression.\n",
    "film = roll(render_hist, fps = fps, duration = duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solve using QMDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QMDP.QMDPSolver(100, 0.001, true)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# QMDP solver settings: at most 100 iterations, tolerance 1e-3, verbose progress output.\n",
    "solver = QMDPSolver(max_iterations=100, tolerance=1e-3, verbose=true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iteration 1   ] residual:          1 | iteration runtime:   2485.151 ms, (      2.49 s total)\n",
      "[Iteration 2   ] residual:       0.95 | iteration runtime:   2570.471 ms, (      5.06 s total)\n",
      "[Iteration 3   ] residual:      0.903 | iteration runtime:   2674.629 ms, (      7.73 s total)\n",
      "[Iteration 4   ] residual:      0.857 | iteration runtime:   2423.629 ms, (      10.2 s total)\n",
      "[Iteration 5   ] residual:      0.811 | iteration runtime:   2424.142 ms, (      12.6 s total)\n",
      "[Iteration 6   ] residual:      0.762 | iteration runtime:   2549.645 ms, (      15.1 s total)\n",
      "[Iteration 7   ] residual:      0.707 | iteration runtime:   2928.046 ms, (      18.1 s total)\n",
      "[Iteration 8   ] residual:      0.649 | iteration runtime:   2481.565 ms, (      20.5 s total)\n",
      "[Iteration 9   ] residual:      0.539 | iteration runtime:   2532.580 ms, (      23.1 s total)\n",
      "[Iteration 10  ] residual:      0.472 | iteration runtime:   3136.414 ms, (      26.2 s total)\n",
      "[Iteration 11  ] residual:       0.42 | iteration runtime:   3064.089 ms, (      29.3 s total)\n",
      "[Iteration 12  ] residual:      0.384 | iteration runtime:   2468.989 ms, (      31.7 s total)\n",
      "[Iteration 13  ] residual:      0.177 | iteration runtime:   2378.109 ms, (      34.1 s total)\n",
      "[Iteration 14  ] residual:      0.129 | iteration runtime:   2530.340 ms, (      36.6 s total)\n",
      "[Iteration 15  ] residual:     0.0995 | iteration runtime:   2629.531 ms, (      39.3 s total)\n",
      "[Iteration 16  ] residual:      0.085 | iteration runtime:   2415.731 ms, (      41.7 s total)\n",
      "[Iteration 17  ] residual:     0.0589 | iteration runtime:   2564.706 ms, (      44.3 s total)\n",
      "[Iteration 18  ] residual:     0.0357 | iteration runtime:   2809.825 ms, (      47.1 s total)\n",
      "[Iteration 19  ] residual:     0.0198 | iteration runtime:   2716.602 ms, (      49.8 s total)\n",
      "[Iteration 20  ] residual:     0.0102 | iteration runtime:   2992.795 ms, (      52.8 s total)\n",
      "[Iteration 21  ] residual:    0.00499 | iteration runtime:   2730.547 ms, (      55.5 s total)\n",
      "[Iteration 22  ] residual:    0.00234 | iteration runtime:   2889.679 ms, (      58.4 s total)\n",
      "[Iteration 23  ] residual:     0.0011 | iteration runtime:   2635.757 ms, (        61 s total)\n",
      "[Iteration 24  ] residual:   0.000503 | iteration runtime:   3490.735 ms, (      64.5 s total)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "POMDPToolbox.AlphaVectorPolicy{AutomotivePOMDPs.SingleOCPOMDP,AutomotivePOMDPs.SingleOCAction}(AutomotivePOMDPs.SingleOCPOMDP(AutomotivePOMDPs.CrosswalkEnv(Roadway, AutomotiveDrivingModels.Lane(LaneTag(2, 1), AutomotiveDrivingModels.CurvePt[CurvePt({25.000, -10.000, 1.571}, 0.000, 0.000, NaN), CurvePt({25.000, 10.000, 1.571}, 20.000, 0.000, NaN)], 6.0, AutomotiveDrivingModels.SpeedLimit(-Inf, Inf), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneConnection[], AutomotiveDrivingModels.LaneConnection[]), AutomotiveDrivingModels.ConvexPolygon[ConvexPolygon: len 4 (max 4 pts)\n",
       "\tVecE2(15.000, -1.500)\n",
       "\tVecE2(15.000, -4.500)\n",
       "\tVecE2(21.500, -4.500)\n",
       "\tVecE2(21.500, -1.500)\n",
       "], AutomotivePOMDPs.CrosswalkParams(2, 50.0, 3.0, 20.0, 6.0, 5.0, 37.0, 8.0, 100, 0.5, 2.0, 10.0)), VehicleDef(CAR, 4.000, 1.800), VehicleDef(PEDESTRIAN, 1.000, 1.000), 2.0, 1.0, 1.0, 5.0, -5.0, 37.0, 5.0, 0.5, 0.3, 0.3, false, 1.0, 1.0, 1.0, -1.0, 0.0, 1.0, 0.95), Array{Float64,1}[[0.445726, 0.478471, 0.478471, 0.478471, 0.50807, 0.50807, 0.50807, 0.526613, 0.526613, 0.526613  …  1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.445726, 0.478471, 0.478471, 0.478471, 0.50807, 0.50807, 0.50807, 0.526613, 0.526613, 0.526613  …  1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.445726, 0.478471, 0.478471, 0.478471, 0.50807, 0.50807, 0.50807, 0.526613, 0.526613, 0.526613  …  1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.449222, 0.479646, 0.479646, 0.479646, 0.515838, 0.515838, 0.515838, 0.544343, 0.544343, 0.544343  …  1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], AutomotivePOMDPs.SingleOCAction[AutomotivePOMDPs.SingleOCAction(-4.0), AutomotivePOMDPs.SingleOCAction(-2.0), AutomotivePOMDPs.SingleOCAction(0.0), AutomotivePOMDPs.SingleOCAction(2.0)])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Solve the POMDP with the QMDP solver; there is no trailing `;`, so the returned\n",
    "# policy is displayed in full.\n",
    "qmdp_policy = solve(solver, pomdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `up` to a SingleOCUpdater built from the POMDP; it replaces the updater\n",
    "# that was derived from the random policy.\n",
    "up = SingleOCUpdater(pomdp);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 45.029022 seconds (658.04 M allocations: 17.745 GiB, 16.49% gc time)\n"
     ]
    }
   ],
   "source": [
    "# New recorder; unlike the random-policy run, no max_steps is given here.\n",
    "hr = HistoryRecorder(rng=rng)\n",
    "# Simulate the QMDP policy with the SingleOCUpdater and report the elapsed time.\n",
    "@time hist = simulate(hr, pomdp, qmdp_policy, up);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-7440889548517793390.webm?17756073931288784939\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"/tmp/tmpRckSFs\", 0x0000000000000013, 2.0, nothing)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the QMDP run; no overlay argument is passed to animate_hist in this call.\n",
    "duration, fps, render_hist = animate_hist(pomdp, hist)\n",
    "# Assemble and display the film.\n",
    "film = roll(render_hist, fps = fps, duration = duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QMDP Utility Fusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Type aliases for Int64-keyed dictionaries of SingleOC beliefs, states and observations.\n",
    "# NOTE(review): no other cell in this notebook references these aliases; the keys are\n",
    "# presumably pedestrian ids -- confirm before relying on them.\n",
    "const DecBelief = Dict{Int64, SingleOCDistribution}\n",
    "const DecState = Dict{Int64, SingleOCState}\n",
    "const DecObs = Dict{Int64, SingleOCObs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Bound to its own name so that `pomdp` keeps referring to the SingleOCPOMDP that the\n",
    "# SARSOP cells further down still solve and simulate when the notebook is run top to bottom.\n",
    "multi_pomdp = OCPOMDP()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Updater built from the OCPOMDP and its ΔT field.\n",
    "up = MixedUpdater(multi_pomdp, multi_pomdp.ΔT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "hr = HistoryRecorder(rng=rng)\n",
    "# `d0` was referenced without ever being defined; use the problem's initial state\n",
    "# distribution as the starting belief.\n",
    "# NOTE(review): assumes OCPOMDP implements initial_state_distribution and that\n",
    "# MixedUpdater can initialize a belief from it -- confirm.\n",
    "d0 = initial_state_distribution(multi_pomdp)\n",
    "s0 = initialstate(multi_pomdp, rng)\n",
    "# Simulate the QMDP policy on the OCPOMDP from the explicit initial belief and state.\n",
    "@time hist = simulate(hr, multi_pomdp, qmdp_policy, up, d0, s0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solve using SARSOP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SARSOP.SARSOPSolver(Dict{AbstractString,Any}(Pair{AbstractString,Any}(\"randomization\", \"\"),Pair{AbstractString,Any}(\"timeout\", 600.0)))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebind `solver` to a SARSOP solver: randomization on, precision 0.001, timeout 600\n",
    "# (presumably seconds -- confirm against the SARSOP options).\n",
    "solver = SARSOPSolver(randomization=true, precision = 0.001, timeout = 600.)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating a pomdpx file: crosswalk.pomdpx\n",
      "\n",
      "Loading the model ...\n",
      "  input file   : crosswalk.pomdpx\n",
      "  loading time : 148.62s \n",
      "\n",
      "SARSOP initializing ...\n",
      "  initialization time : 178.54s\n",
      "\n",
      "-------------------------------------------------------------------------------\n",
      " Time   |#Trial |#Backup |LBound    |UBound    |Precision  |#Alphas |#Beliefs  \n",
      "-------------------------------------------------------------------------------\n",
      " 178.544 0       0        12.1169    12.1191    0.00217179  4        1        \n",
      " 208.314 4       19       12.1182    12.1191    0.000954115 17       11       \n",
      "-------------------------------------------------------------------------------\n",
      "\n",
      "SARSOP finishing ...\n",
      "  target precision reached\n",
      "  target precision  : 0.001000\n",
      "  precision reached : 0.000954 \n",
      "\n",
      "-------------------------------------------------------------------------------\n",
      " Time   |#Trial |#Backup |LBound    |UBound    |Precision  |#Alphas |#Beliefs  \n",
      "-------------------------------------------------------------------------------\n",
      " 208.315 4       19       12.1182    12.1191    0.000954115 17       11       \n",
      "-------------------------------------------------------------------------------\n",
      "\n",
      "Writing out policy ...\n",
      "  output file : crosswalk.policy\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "SARSOP.POMDPPolicy(\"crosswalk.policy\", POMDPXFiles.POMDPAlphas([10.2753 9.80559 … 10.2753 10.2753; 10.3449 9.93102 … 10.3449 10.3449; … ; 19.9998 19.9998 … 19.9999 19.9999; 19.9998 19.9998 … 19.9999 19.9999], [3, 1, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3]), AutomotivePOMDPs.SingleOCPOMDP(AutomotivePOMDPs.CrosswalkEnv(Roadway, AutomotiveDrivingModels.Lane(LaneTag(2, 1), AutomotiveDrivingModels.CurvePt[CurvePt({25.000, -10.000, 1.571}, 0.000, 0.000, NaN), CurvePt({25.000, 10.000, 1.571}, 20.000, 0.000, NaN)], 6.0, AutomotiveDrivingModels.SpeedLimit(-Inf, Inf), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneConnection[], AutomotiveDrivingModels.LaneConnection[]), AutomotiveDrivingModels.ConvexPolygon[ConvexPolygon: len 4 (max 4 pts)\n",
       "\tVecE2(15.000, -1.500)\n",
       "\tVecE2(15.000, -4.500)\n",
       "\tVecE2(21.500, -4.500)\n",
       "\tVecE2(21.500, -1.500)\n",
       "], AutomotivePOMDPs.CrosswalkParams(2, 50.0, 3.0, 20.0, 6.0, 5.0, 37.0, 8.0, 100, 0.5, 2.0, 10.0)), VehicleDef(CAR, 4.000, 1.800), VehicleDef(PEDESTRIAN, 1.000, 1.000), 2.0, 1.0, 1.0, 5.0, -5.0, 37.0, 5.0, 0.5, 0.3, 0.3, false, 1.0, 1.0, 1.0, -1.0, 0.0, 1.0, 0.95), Any[AutomotivePOMDPs.SingleOCAction(-4.0), AutomotivePOMDPs.SingleOCAction(-2.0), AutomotivePOMDPs.SingleOCAction(0.0), AutomotivePOMDPs.SingleOCAction(2.0)])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Solve with SARSOP: the model is exported to crosswalk.pomdpx and the computed\n",
    "# policy is written to crosswalk.policy.\n",
    "sarsop_policy = solve(solver, pomdp,\n",
    "               SARSOP.create_policy(solver, pomdp, \"crosswalk.policy\"), pomdp_file_name=\"crosswalk.pomdpx\" )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 11.688374 seconds (165.43 M allocations: 13.435 GiB, 16.18% gc time)\n"
     ]
    }
   ],
   "source": [
    "# New recorder for the SARSOP run.\n",
    "hr = HistoryRecorder(rng=rng)\n",
    "# Simulate the SARSOP policy with a DiscreteUpdater over the POMDP and report the elapsed time.\n",
    "@time hist = simulate(hr, pomdp, sarsop_policy, DiscreteUpdater(pomdp));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-11739988796639125612.webm?5755122942129870572\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"/tmp/tmpoBsqbv\", 0x0000000000000009, 2.0, nothing)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the SARSOP run from the recorded history.\n",
    "duration, fps, render_hist = animate_hist(pomdp, hist)\n",
    "# Assemble and display the film.\n",
    "film = roll(render_hist, fps = fps, duration = duration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.1",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
